import torch
from logging import getLogger

from SMTWTModel import SMTWTModel as Model
from SWTWTEnv import PFSPEnv as Env
from Data_generator import SMTWTDataset
from torch.utils.data import DataLoader

from torch.optim import Adam as Optimizer
from torch.optim.lr_scheduler import MultiStepLR as Scheduler

from utils import *
# Module-level default device: CUDA when it is available, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class PFSPTrainer:
    """Training driver for the SMTWT policy model.

    Every epoch a fresh dataset is generated, the model is rolled out in the
    environment with randomly sampled latent variables, and the model is
    trained to raise the log-probability of the best trajectory (along the
    pomo dimension) of each instance. The per-instance loss weight is
    ``|max - mean| / std`` of the rewards across the pomo dimension.
    """

    def __init__(self,
                 env_params,
                 model_params,
                 optimizer_params,
                 trainer_params):
        """Build the model, environment, optimizer and scheduler.

        Args:
            env_params: keyword arguments forwarded to the environment; must
                contain 'job_cnt' and 'pomo_size'.
            model_params: keyword arguments forwarded to the model; must
                contain 'latent_cont_size' and 'latent_disc_size'.
            optimizer_params: dict holding the keyword-argument dicts
                'optimizer' and 'scheduler'.
            trainer_params: dict holding 'train_batch_size',
                'accumulation_step', 'use_cuda', 'cuda_device_num' (read only
                when 'use_cuda' is true), 'model_load', 'epochs' and
                'train_episodes'.
        """
        # save arguments
        self.env_params = env_params
        self.model_params = model_params
        self.optimizer_params = optimizer_params
        self.trainer_params = trainer_params

        # result folder, logger
        self.logger = getLogger(name='trainer')
        self.result_folder = get_result_folder()
        self.result_log = LogData()

        # env params
        self.n_jobs = self.env_params['job_cnt']
        self.pomo_size = self.env_params['pomo_size']
        self.latent_cont_dim = self.model_params['latent_cont_size']
        self.latent_disc_dim = self.model_params['latent_disc_size']

        # trainer params
        self.train_batch_size = trainer_params['train_batch_size']
        self.accumulation_step = trainer_params['accumulation_step']

        # cuda setting; the chosen device is kept so that run() uses the same
        # device as the one configured here
        self.device = self._configure_device()

        ##### Main Components
        self.model = Model(**self.model_params)
        self.env = Env(**self.env_params)

        # The optimizer and scheduler must exist before a checkpoint is
        # restored, because restoring writes into both of them.
        self.optimizer = Optimizer(self.model.parameters(), **self.optimizer_params['optimizer'])
        self.scheduler = Scheduler(self.optimizer, **self.optimizer_params['scheduler'])

        # Reset
        self.start_epoch = 1
        # Load pre-trained model
        model_load = trainer_params['model_load']
        if model_load['enable']:
            self._load_checkpoint(model_load)

        self.time_estimator = TimeEstimator()

    def _configure_device(self):
        """Apply the CUDA/CPU settings from ``trainer_params``.

        Returns:
            The ``torch.device`` selected by 'use_cuda' / 'cuda_device_num'.
        """
        if self.trainer_params['use_cuda']:
            cuda_device_num = self.trainer_params['cuda_device_num']
            torch.cuda.set_device(cuda_device_num)
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
            return torch.device('cuda', cuda_device_num)

        # set_default_dtype takes a torch.dtype, not a tensor-type string.
        torch.set_default_dtype(torch.float32)
        return torch.device('cpu')

    def _load_checkpoint(self, model_load):
        """Restore model weights (and optionally optimizer state) from disk.

        Args:
            model_load: dict with 'path', 'epoch' and 'load_model_only'.
                When 'load_model_only' is false, training resumes from the
                epoch after 'epoch' and the optimizer state is restored too.
        """
        checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
        checkpoint = torch.load(checkpoint_fullname, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        if not model_load['load_model_only']:
            self.start_epoch = 1 + model_load['epoch']
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.scheduler.last_epoch = model_load['epoch'] - 1
        self.logger.info('Saved Model Loaded !!')

    @staticmethod
    def _sample_latent(batch_size, pomo_size, disc_dim, cont_dim):
        """Sample one latent vector per (instance, pomo) pair.

        Returns:
            Tensor of shape (batch_size, pomo_size, disc_dim + cont_dim): a
            random one-hot code of width ``disc_dim`` followed by ``cont_dim``
            values drawn uniformly from [-1, 1].
        """
        # The continuous part is drawn first to keep the RNG consumption order.
        latent_c_var = torch.empty(batch_size, pomo_size, cont_dim).uniform_(-1, 1)

        one_hot_idx = torch.randint(0, disc_dim, (batch_size, pomo_size), dtype=torch.long)
        latent_d_var = torch.zeros((batch_size, pomo_size, disc_dim), dtype=torch.float32)
        latent_d_var.scatter_(-1, one_hot_idx.unsqueeze(-1), 1.0)

        return torch.cat([latent_d_var, latent_c_var], dim=-1)

    @staticmethod
    def _best_trajectory_loss(reward, prob_list):
        """Compute the weighted negative log-likelihood of the best trajectory.

        Args:
            reward: tensor of shape (batch, pomo).
            prob_list: tensor of shape (batch, pomo, steps) with the
                probability of the action selected at each step.

        Returns:
            Tuple ``(loss, best_reward)`` where ``loss`` is a scalar tensor and
            ``best_reward`` has shape (batch,) and holds the maximum reward
            over the pomo dimension.
        """
        best_reward, best_idx = reward.max(dim=1, keepdim=True)
        mean_reward = reward.mean(dim=1, keepdim=True)
        pomo_variance = reward.var(dim=1, keepdim=True, unbiased=False)
        # How many standard deviations the best reward lies above the mean.
        loss_weight = torch.abs(best_reward - mean_reward) / torch.sqrt(pomo_variance + 1e-8)

        batch_idx = torch.arange(reward.size(0), device=reward.device)
        # The final step is excluded from the likelihood
        # NOTE(review): presumably because the last choice is forced - confirm.
        probs = prob_list[batch_idx, best_idx.squeeze(1), :-1]
        log_probs = torch.log(probs + 1e-8)

        loss = -(log_probs * loss_weight).mean()
        return loss, best_reward.squeeze(1)

    def _rollout(self, train_data):
        """Roll the model out on one batch.

        Args:
            train_data: batch tensor whose first dimension is the batch size.

        Returns:
            Tuple ``(reward, prob_list)``: the reward returned by the last
            environment step and the per-step probabilities of the selected
            actions, concatenated along dim 2.
        """
        # Sized from the actual batch so a smaller final batch still lines up.
        # NOTE(review): assumes load_problems_manual takes its batch size from
        # the data it is given - confirm against the environment.
        batch_size = train_data.size(0)
        latent_var = self._sample_latent(
            batch_size, self.env.pomo_size, self.latent_disc_dim, self.latent_cont_dim)

        self.model.train()
        self.env.load_problems_manual(train_data)
        reset_state, _, _ = self.env.reset()

        self.model.pre_forward(reset_state, latent_var)

        selected_list = torch.zeros(size=(batch_size, self.pomo_size, 0), dtype=torch.long)
        prob_list = torch.zeros(size=(batch_size, self.env.pomo_size, 0))

        state, reward, done = self.env.pre_step()

        while not done:
            selected, prob = self.model(state, selected_list)
            state, reward, done = self.env.step(selected)

            prob_list = torch.cat((prob_list, prob[:, :, None]), dim=2)
            selected_list = torch.cat((selected_list, selected[:, :, None]), dim=2)

        return reward, prob_list

    def _train_one_epoch(self, epoch):
        """Train on one freshly generated dataset.

        Returns:
            Tuple ``(score_avg, loss_avg)`` averaged over the epoch.
        """
        score_AM = AverageMeter()
        loss_AM = AverageMeter()

        train_episodes = self.trainer_params['train_episodes']
        train_dataset = SMTWTDataset(train_episodes, self.n_jobs).data
        train_dataloader = DataLoader(train_dataset, batch_size=self.train_batch_size, shuffle=False,
                                      generator=torch.Generator(device=self.device))

        # Per-batch progress is only reported during the first epoch of a run.
        log_progress = (epoch == self.start_epoch)

        for i, train_data in enumerate(train_dataloader):
            batch_size = train_data.size(0)

            reward, prob_list = self._rollout(train_data)
            loss, max_pomo_reward = self._best_trajectory_loss(reward, prob_list)

            # Scale so that accumulated gradients average over the window.
            loss = loss / self.accumulation_step
            loss.backward()
            if (i + 1) % self.accumulation_step == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

            score_mean = -max_pomo_reward.float().mean()

            score_AM.update(score_mean.item(), batch_size)
            # Undo the accumulation scaling so the logged loss is per batch.
            loss_AM.update(loss.item() * self.accumulation_step, batch_size)

            if log_progress and i < 10:
                self.logger.info('Epoch {:3d}: Train {:3d}/{:3d}({:1.1f}%)  Score: {:.4f},  Loss: {:.4f}'
                                 .format(epoch, i * self.train_batch_size, train_episodes,
                                         100. * i * self.train_batch_size / train_episodes,
                                         score_AM.avg, loss_AM.avg))

        return score_AM.avg, loss_AM.avg

    def run(self):
        """Train from ``start_epoch`` through ``trainer_params['epochs']``.

        After each epoch the scheduler is stepped, the epoch score and loss
        are recorded and logged, and a checkpoint holding the model, optimizer
        and scheduler state is written to the result folder.
        """
        total_epochs = self.trainer_params['epochs']

        # Start train
        self.time_estimator.reset(self.start_epoch)  # Time reset
        self.logger.info('Start model training...')

        for epoch in range(self.start_epoch, total_epochs + 1):
            self.logger.info('=================================================================')

            train_score, train_loss = self._train_one_epoch(epoch)

            ############################
            # Logs & Checkpoint
            ############################
            self.scheduler.step()

            self.result_log.append('train_score', epoch, train_score)
            self.result_log.append('train_loss', epoch, train_loss)

            elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(epoch, total_epochs)

            self.logger.info("Epoch {:3d}/{:3d}: Time Est.: Elapsed[{}], Remain[{}]".format(
                epoch, total_epochs, elapsed_time_str, remain_time_str))

            self.logger.info('Epoch {:3d}:  Trian set Score: {:.4f},  CE Loss: {:.4f}'.format(epoch, train_score, train_loss))

            # A single write: the file carries everything needed to resume.
            checkpoint_dict = {
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'scheduler_state_dict': self.scheduler.state_dict()
            }
            torch.save(checkpoint_dict, '{}/checkpoint-{}.pt'.format(self.result_folder, epoch))

            ################ Check done
            if epoch == total_epochs:
                self.logger.info(" *** Training Done *** ")
                self.logger.info("Now, printing log array...")
                util_print_log_array(self.logger, self.result_log)